package org.bdware.analysis.dynamic;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bdware.analysis.BasicBlock;
import org.bdware.analysis.BreadthFirstSearch;
import org.bdware.analysis.OpInfo;
import org.bdware.analysis.taint.TaintBB;
import org.bdware.analysis.taint.TaintCFG;
import org.bdware.analysis.taint.TaintResult;
import org.bdware.analysis.taint.TaintValue;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.MethodNode;

import java.util.*;

/**
 * Dynamic taint analysis over a {@link TaintCFG} that resolves conditional branches using the
 * values recorded in a {@link TracedFile}, instead of always following every static successor.
 *
 * <p>Blocks are visited through {@link BreadthFirstSearch}. After a block is executed, its
 * {@code sucResult} is merged into the {@code preResult} of each successor chosen by
 * {@link #getSuc(TaintBB)}.
 */
public class FieldSensitiveDynamicTaintAnalysis extends BreadthFirstSearch<TaintResult, TaintBB> {
    private static final Logger LOGGER = LogManager.getLogger(FieldSensitiveDynamicTaintAnalysis.class);
    /** Sentinel returned by {@link #handleBranchCase} when the trace cannot decide the successor. */
    private static final int UNRESOLVED_BRANCH = -1;
    /** When true, the constructor additionally logs the method signature and frame sizes. */
    public static boolean isDebug = false;
    TaintCFG cfg;
    TracedFile tf;
    /** Number of IFNE instructions handled by {@link #handleBranchCase}. */
    int count = 0;

    /**
     * Creates the analysis and seeds the entry block (block 0) with the initial taint state.
     *
     * @param cfg control-flow graph of the method to analyse
     * @param tf  recorded trace used to decide which way conditional branches go
     */
    public FieldSensitiveDynamicTaintAnalysis(TaintCFG cfg, TracedFile tf) {
        this.cfg = cfg;
        this.tf = tf;
        List<TaintBB> toAnalysis = new ArrayList<>();
        MethodNode mn = cfg.getMethodNode();
        LOGGER.info("mn: " + mn.name);
        String methodDesc = mn.desc;
        LOGGER.info("method desc: " + methodDesc);
        // Strip the return type, keeping only the "(<params>)" part of the descriptor.
        methodDesc = methodDesc.replaceAll("\\).*$", ")");
        LOGGER.info("method desc: " + methodDesc);
        // Local slot that starts out tainted: 1 when the parameter list splits into exactly three
        // pieces on ';' (i.e. two reference-typed parameters), otherwise 2.
        // NOTE(review): the meaning of this heuristic is not visible here — confirm with callers.
        int pos = 2;
        if (methodDesc.split(";").length == 3) {
            pos = 1;
        }
        // TODO add inputBlock!
        LOGGER.info("p+++ " + pos);
        TaintBB b = (TaintBB) cfg.getBasicBlockAt(0);
        b.preResult = new TaintResult();
        LOGGER.info(TaintResult.nLocals);
        // Locals are only initialised when there are more than two of them.
        // NOTE(review): presumably the second TaintValue argument is the taint flag — confirm.
        if (TaintResult.nLocals > 2) {
            for (int i = 0; i < TaintResult.nLocals; i++) {
                b.preResult.frame.setLocal(i, new TaintValue(1, i == pos ? 1 : 0));
            }
        }
        b.preResult.ret = new TaintValue(1);
        toAnalysis.add(b);
        b.setInList(true);
        setToAnalysis(toAnalysis);
        if (isDebug) {
            LOGGER.info("Method: " + mn.name + " " + mn.desc);
            LOGGER.info("Local: " + mn.maxLocals + " " + mn.maxStack);
        }
    }

    /** Runs the forward transfer function of a single block. */
    @Override
    public TaintResult execute(TaintBB t) {
        return t.forwardAnalysis();
    }

    /**
     * Chooses the successors of a finished block and merges its {@code sucResult} into each of
     * their {@code preResult}s.
     *
     * <p>Instructions that can throw, and switches, continue to the next block by id. Branches are
     * resolved from the trace by {@link #handleBranchCase}. When the trace cannot decide a branch,
     * and for every other instruction, all static successors from the CFG are used.
     *
     * @param t the block that was just executed
     * @return the successor blocks, with their pre-states already updated
     */
    @Override
    public Collection<TaintBB> getSuc(TaintBB t) {
        Set<BasicBlock> subBlock = new HashSet<>();
        if (t.list.size() > 0) {
            AbstractInsnNode insn = t.lastInsn();
            if (insn != null) {
                int opcode = insn.getOpcode();
                // Pseudo-instructions (labels, frames, line numbers) report a negative opcode.
                OpInfo info = opcode >= 0 ? OpInfo.ops[opcode] : null;
                if (info == null) {
                    subBlock.addAll(cfg.getSucBlocks(t));
                } else if (info.canThrow()) {
                    subBlock.add(cfg.getBasicBlockAt(t.blockID + 1));
                } else if (info.canBranch()) {
                    int block = handleBranchCase(insn, t);
                    if (block == UNRESOLVED_BRANCH) {
                        // The trace cannot decide this branch: stay conservative instead of
                        // (wrongly) jumping back to the entry block.
                        subBlock.addAll(cfg.getSucBlocks(t));
                    } else {
                        subBlock.add(cfg.getBasicBlockAt(block));
                    }
                } else if (info.canSwitch()) {
                    subBlock.add(cfg.getBasicBlockAt(t.blockID + 1));
                } else {
                    subBlock.addAll(cfg.getSucBlocks(t));
                }
            }
        }
        Set<TaintBB> ret = new HashSet<>();
        for (BasicBlock bb : subBlock) {
            TaintBB ntbb = (TaintBB) bb;
            ntbb.preResult.mergeResult(t.sucResult);
            ret.add(ntbb);
        }
        return ret;
    }

    /**
     * Resolves the successor block of a branching instruction.
     *
     * <p>For single-operand {@code IF<cond>} opcodes, the branch condition is evaluated against
     * {@code tf.trans.get(0).tmToVal.get(0).get(0)} (the first trace mark, e.g.
     * {"traceMark":0, "val":1}). The same recorded value is used for every such branch.
     *
     * @param insn the branching instruction ending the block
     * @param t    the block ending with {@code insn}
     * @return the jump target's block id when the branch is taken (always for GOTO),
     *         {@code t.blockID + 1} when it falls through, or {@link #UNRESOLVED_BRANCH} when the
     *         opcode is not modelled
     */
    private int handleBranchCase(AbstractInsnNode insn, TaintBB t) {
        int opcode = insn.getOpcode();
        if (opcode == Opcodes.GOTO) {
            return jumpTargetBlock(insn);
        }
        boolean taken;
        switch (opcode) {
            case Opcodes.IFNE: // succeeds if and only if value != 0
                count++;
                taken = tf.trans.get(0).tmToVal.get(0).get(0) != 0;
                break;
            case Opcodes.IFEQ: // succeeds if and only if value == 0
                taken = tf.trans.get(0).tmToVal.get(0).get(0) == 0;
                break;
            case Opcodes.IFGE: // succeeds if and only if value >= 0
                taken = tf.trans.get(0).tmToVal.get(0).get(0) >= 0;
                break;
            case Opcodes.IFLE: // succeeds if and only if value <= 0
                taken = tf.trans.get(0).tmToVal.get(0).get(0) <= 0;
                break;
            case Opcodes.IFLT: // succeeds if and only if value < 0
                taken = tf.trans.get(0).tmToVal.get(0).get(0) < 0;
                break;
            case Opcodes.IFGT: // succeeds if and only if value > 0
                taken = tf.trans.get(0).tmToVal.get(0).get(0) > 0;
                break;
            default:
                // Two-operand compares (IF_ICMP*, IF_ACMP*), IFNULL/IFNONNULL, JSR and switches
                // are not resolved from the trace.
                return UNRESOLVED_BRANCH;
        }
        return taken ? jumpTargetBlock(insn) : t.blockID + 1;
    }

    /**
     * Returns the id of the block starting at the label targeted by a jump instruction, or
     * {@link #UNRESOLVED_BRANCH} if {@code insn} is not a {@link JumpInsnNode}.
     */
    private int jumpTargetBlock(AbstractInsnNode insn) {
        if (!(insn instanceof JumpInsnNode)) {
            return UNRESOLVED_BRANCH;
        }
        LabelNode jump = ((JumpInsnNode) insn).label;
        return cfg.getBasicBlockByLabel(jump.getLabel()).blockID;
    }

}
